2 万 Star!屡屡斩获 Kaggle 各大竞赛冠军宝座的利器
一、简介
XGBoost (eXtreme Gradient Boosting) 是“极端梯度提升”的简称,它是为满足高效、灵活和可移植性的目的而“诞生”,也被称为优化的分布式增强库。支持 Python、Scala、R、Julia、C++ 等语言。
XGBoost 所应用的算法是机器学习算法—GBDT(gradient boosting decision tree)的改进,既可以用于分类也可以用于回归问题中。
XGBoost 提供了一种并行树增强(也被称为GBDT、GBM),可以快速准确地解决许多数据科学问题。相同的代码可以运行在目前主流的分布式环境(Kubernetes、Hadoop、SGE、MPI、Dask),并且可以解决超过数十亿个样例计算的问题。
XGBoost 可以处理回归、分类和排序等多种任务。由于它在预测性能上的强大且训练速度快,XGBoost 已屡屡斩获 Kaggle 各大竞赛的冠军宝座。
二、开源主页
https://github.com/dmlc/xgboost
XGBoost 在 GitHub 已获得 20.4 k Star。
三、安装和使用案例
3.1、下载(windows10 64位,python 3.7.3)
打开 cmd 命令模式,输入 pip install xgboost
但网络需要架梯子(不知道是不是我的网不好),我不架梯子会因网络问题出现以下报错:
架梯子后,虽然速度慢了点,但能成功安装。
2、如果没有梯子,可以从这个网址(https://www.lfd.uci.edu/~gohlke/pythonlibs/#xgboost) ,找到与自己 Python 版本
xgboost
然后重新执行 pip install xgboost-0.90-cp37-cp37m-win_amd64.whl 可以安装成功
3.2、使用示例
在这个例子中,将使用墨尔本房屋数据集进行简单的测试使用——预测房价。
数据集下载方式:请关注微信公号
开源前哨
,然后发送xgboost
即可获取。
在使用这数据之前,需要对三种常见的数值问题进行处理。
1、缺失值的处理(通常处理方法是直接删除缺少值的列或利用该列均值填充);
2、异常值的处理;
3、重复值的处理等;
在这里我们不会关注数据的处理加载步骤(请自行学习),假设已经拥有(我处理过了)了X_train(样本训练集)、X_valid(样本验证集)、y_train(标签训练集)和y_valid(标签验证集)的训练和验证数据。
下面是用 jupyter notebook 进行的代码执行的代码:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# 装载数据
data = pd.read_csv('E:\数据集\melb_data.csv')
#选择测试集列和预测目标列
cols_to_use = ['Rooms', 'Distance', 'Landsize', 'BuildingArea', 'YearBuilt']
X = data[cols_to_use]
y = data.Price
#将数据分为训练集和测试集
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
#本例使用XGBoost库
#先导入该库,
from xgboost import XGBRegressor
xg_model = XGBRegressor()
#利用训练集训练模型 语句简单
xg_model.fit(X_train,y_train)
#本例使用XGBoost库
#先导入该库,
from xgboost import XGBRegressor
xg_model = XGBRegressor()
#利用训练集训练模型 语句简单
xg_model.fit(X_train,y_train)
#Out
'''XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
importance_type='gain', interaction_constraints='',
learning_rate=0.300000012, max_delta_step=0, max_depth=6,
min_child_weight=1, missing=nan, monotone_constraints='()',
n_estimators=100, n_jobs=12, num_parallel_tree=1,
objective='reg:squarederror', random_state=0, reg_alpha=0,
reg_lambda=1, scale_pos_weight=1, subsample=1, tree_method='exact',
validate_parameters=1, verbosity=None)'''
#先对模型进行预测和用测试集进行评估,得出绝对误差
from sklearn.metrics import mean_absolute_error
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(predictions, y_valid)))
#Out
'''Mean Absolute Error: 235552.95046713916'''
#XGBoost有几个参数对训练结果比较重要,改变参数值,查看评估结果
#learning_rate——学习率(默认是0.1),学习率的出现可以很好的解决过拟合问题,我们改为0.15和0.05看一下不同结果
xg_model = XGBRegressor(learning_rate=0.05)
xg_model.fit(X_train, y_train)
#out
'''XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
importance_type='gain', interaction_constraints='',
learning_rate=0.05, max_delta_step=0, max_depth=6,
min_child_weight=1, missing=nan, monotone_constraints='()',
n_estimators=100, n_jobs=12, num_parallel_tree=1,
objective='reg:squarederror', random_state=0, reg_alpha=0,
reg_lambda=1, scale_pos_weight=1, subsample=1, tree_method='exact',
validate_parameters=1, verbosity=None)'''
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(predictions, y_valid)))
#out
'''Mean Absolute Error: 248468.6911450663'''
xg_model = XGBRegressor(learning_rate=0.15)
xg_model.fit(X_train, y_train)
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(predictions, y_valid)))
#out
'''Mean Absolute Error: 234624.95226435937'''
#n_estimators参数,指的是弱估计器的数量,即树的个数,太低容易导致欠拟合(对训练集和测试集的训练结果误差都很大),过高容易导致过拟合(只对训练集效果很好,但对测试集效果差)
xg_model = XGBRegressor(n_estimators=500)
xg_model.fit(X_train, y_train)
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(predictions, y_valid)))
#out
'''Mean Absolute Error: 247669.14656434095'''
xg_model = XGBRegressor(n_estimators=1000)
xg_model.fit(X_train, y_train)
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(predictions, y_valid)))
#Out
'''Mean Absolute Error: 255654.5402660162'''
#early_stopping_rounds参数,该参数的意思是早点让训练结束,因为在迭代次数到了一定数量时,训练误差会在一个值范围内波动,甚至出现下降的
#的现象,这样就会出现过拟合现象,early_stopping_rounds参数一般设置在40左右,意思是当迭代40轮后,训练误差若出现上升的现象,便提前终止训练。
xg_model = XGBRegressor()
xg_model.fit(X_train, y_train,
early_stopping_rounds=5,eval_set=[(X_valid, y_valid)])
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(predictions, y_valid)))
#Out
'''[0] validation_0-rmse:941935.18750
[1] validation_0-rmse:732831.68750
[2] validation_0-rmse:599871.62500
[3] validation_0-rmse:516429.90625
[4] validation_0-rmse:467964.43750
[5] validation_0-rmse:441869.50000
[6] validation_0-rmse:426009.87500
[7] validation_0-rmse:416771.93750
[8] validation_0-rmse:411553.68750
[9] validation_0-rmse:408549.15625
[10] validation_0-rmse:406357.93750
[11] validation_0-rmse:403870.65625
[12] validation_0-rmse:402537.96875
[13] validation_0-rmse:402280.25000
[14] validation_0-rmse:400586.65625
[15] validation_0-rmse:399610.59375
[16] validation_0-rmse:398340.71875
[17] validation_0-rmse:397867.90625
[18] validation_0-rmse:397690.25000
[19] validation_0-rmse:397726.62500
[20] validation_0-rmse:396976.93750
[21] validation_0-rmse:396865.03125
[22] validation_0-rmse:395752.12500
[23] validation_0-rmse:392401.65625
[24] validation_0-rmse:393190.93750
[25] validation_0-rmse:392763.21875
[26] validation_0-rmse:392791.09375
[27] validation_0-rmse:391552.53125
[28] validation_0-rmse:391745.37500
[29] validation_0-rmse:391624.40625
[30] validation_0-rmse:391096.18750
[31] validation_0-rmse:391777.71875
[32] validation_0-rmse:392427.34375
[33] validation_0-rmse:391748.43750
[34] validation_0-rmse:391423.40625
[35] validation_0-rmse:391501.90625
Mean Absolute Error: 243436.07564893225'''
xg_model = XGBRegressor()
xg_model.fit(X_train, y_train,
early_stopping_rounds=30,eval_set=[(X_valid, y_valid)])
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(predictions, y_valid)))
#Out
'''[0] validation_0-rmse:941935.18750
[1] validation_0-rmse:732831.68750
[2] validation_0-rmse:599871.62500
[3] validation_0-rmse:516429.90625
[4] validation_0-rmse:467964.43750
[5] validation_0-rmse:441869.50000
[6] validation_0-rmse:426009.87500
[7] validation_0-rmse:416771.93750
[8] validation_0-rmse:411553.68750
[9] validation_0-rmse:408549.15625
[10] validation_0-rmse:406357.93750
[11] validation_0-rmse:403870.65625
[12] validation_0-rmse:402537.96875
[13] validation_0-rmse:402280.25000
[14] validation_0-rmse:400586.65625
[15] validation_0-rmse:399610.59375
[16] validation_0-rmse:398340.71875
[17] validation_0-rmse:397867.90625
[18] validation_0-rmse:397690.25000
[19] validation_0-rmse:397726.62500
[20] validation_0-rmse:396976.93750
[21] validation_0-rmse:396865.03125
[22] validation_0-rmse:395752.12500
[23] validation_0-rmse:392401.65625
[24] validation_0-rmse:393190.93750
[25] validation_0-rmse:392763.21875
[26] validation_0-rmse:392791.09375
[27] validation_0-rmse:391552.53125
[28] validation_0-rmse:391745.37500
[29] validation_0-rmse:391624.40625
[30] validation_0-rmse:391096.18750
[31] validation_0-rmse:391777.71875
[32] validation_0-rmse:392427.34375
[33] validation_0-rmse:391748.43750
[34] validation_0-rmse:391423.40625
[35] validation_0-rmse:391501.90625
[36] validation_0-rmse:390788.90625
[37] validation_0-rmse:390342.50000
[38] validation_0-rmse:388972.56250
[39] validation_0-rmse:388548.37500
[40] validation_0-rmse:389271.59375
[41] validation_0-rmse:388126.15625
[42] validation_0-rmse:387813.06250
[43] validation_0-rmse:387755.25000
[44] validation_0-rmse:388325.31250
[45] validation_0-rmse:388446.96875
[46] validation_0-rmse:388394.09375
[47] validation_0-rmse:388425.00000
[48] validation_0-rmse:388244.31250
[49] validation_0-rmse:387900.65625
[50] validation_0-rmse:387721.56250
[51] validation_0-rmse:387184.31250
[52] validation_0-rmse:386687.31250
[53] validation_0-rmse:386059.87500
[54] validation_0-rmse:386105.56250
[55] validation_0-rmse:386133.31250
[56] validation_0-rmse:385496.62500
[57] validation_0-rmse:385332.71875
[58] validation_0-rmse:385390.18750
[59] validation_0-rmse:385281.46875
[60] validation_0-rmse:385243.71875
[61] validation_0-rmse:385267.50000
[62] validation_0-rmse:385012.65625
[63] validation_0-rmse:385141.46875
[64] validation_0-rmse:384997.34375
[65] validation_0-rmse:385355.09375
[66] validation_0-rmse:385625.09375
[67] validation_0-rmse:385546.90625
[68] validation_0-rmse:385723.62500
[69] validation_0-rmse:385636.68750
[70] validation_0-rmse:385617.34375
[71] validation_0-rmse:385682.62500
[72] validation_0-rmse:385741.12500
[73] validation_0-rmse:385583.09375
[74] validation_0-rmse:385650.28125
[75] validation_0-rmse:385895.87500
[76] validation_0-rmse:385654.53125
[77] validation_0-rmse:385794.40625
[78] validation_0-rmse:385448.96875
[79] validation_0-rmse:385472.40625
[80] validation_0-rmse:385382.06250
[81] validation_0-rmse:385556.25000
[82] validation_0-rmse:385969.62500
[83] validation_0-rmse:385744.06250
[84] validation_0-rmse:385602.81250
[85] validation_0-rmse:385691.06250
[86] validation_0-rmse:385536.59375
[87] validation_0-rmse:385594.53125
[88] validation_0-rmse:385643.12500
[89] validation_0-rmse:385773.93750
[90] validation_0-rmse:385573.40625
[91] validation_0-rmse:385784.81250
[92] validation_0-rmse:385986.00000
[93] validation_0-rmse:386088.37500
Mean Absolute Error: 237037.17186349412'''
以上便是这次 Xgboost 的安装及简单实用,大家可以根据自己训练模型的需要,去学习它更多的 API 方法。
- EOF -
更多优秀开源项目(点击下方图片可跳转)
开源前哨
日常分享热门、有趣和实用的开源项目。参与维护10万+star的开源技术资源库,包括:Python, Java, C/C++, Go, JS, CSS, Node.js, PHP, .NET 等
关注后获取
回复 资源 获取10万+star开源资源
分享、点赞和在看
支持我们分享更多优秀开源项目,谢谢!